{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calibration examples\n",
    "\n",
    "This notebook shows how to calculate different calibration errors and apply different calibration techniques.\n",
    "\n",
    "The following calibration errors are implemented in the alpaca library:\n",
    "- Expected Calibration Error (ECE)\n",
    "- Static Calibration Error (SCE)\n",
    "- Adaptive Calibration Error (ACE)\n",
    "- Thresholded Adaptive Calibration Error (TACE)\n",
    "\n",
    "The calibration methods demonstrated are:\n",
    "- Temperature Scaling\n",
    "- Vector Scaling\n",
    "- Matrix Scaling\n",
    "- Histogram Binning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "s0rU2DMvB2rL",
    "outputId": "b11f83ea-ae50-4f8e-92a4-0d585da2f5be"
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sklearn\n",
    "from sklearn.metrics import accuracy_score\n",
    "from scipy.special import softmax\n",
    "\n",
    "from torch.nn import functional as f\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "from alpaca.utils.datasets.builder import build_dataset\n",
    "import alpaca.calibrator as calibrator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function that calculates all the calibration errors (ECE, SCE, ACE, TACE) and prints them\n",
    "def compute_errors(n_bins, probs, labels, len_dataset, threshold):\n",
    "    ece = calibrator.compute_ece(n_bins, probs, labels, len_dataset)\n",
    "    sce = calibrator.compute_sce(n_bins, probs, labels)\n",
    "    ace = calibrator.compute_ace(n_bins, probs, labels)\n",
    "    tace = calibrator.compute_tace(threshold, probs, labels, n_bins)\n",
    "    errors = {\n",
    "        'ece' : ece,\n",
    "        'sce' : sce,\n",
    "        'ace' : ace,\n",
    "        'tace' : tace\n",
    "    }\n",
    "    for name, calibration_error in errors.items():\n",
    "        print(name, ' = ', calibration_error)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model training\n",
    "We showcase the calibration approaches with a simple neural network and the MNIST dataset. To start with, we'll take the data, build the neural net and train it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000, 784)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mnist = build_dataset('mnist', val_size=10_000)\n",
    "X_train, y_train = mnist.dataset('train')\n",
    "X_val, y_val = mnist.dataset('val')\n",
    "X_cal = X_train[48000:][:]\n",
    "X_train = X_train[0:48000][:]\n",
    "y_cal = y_train[48000:][:]\n",
    "y_train = y_train[0:48000][:]\n",
    "\n",
    "x_shape = (-1, 1, 28, 28)\n",
    "\n",
    "train_ds = TensorDataset(torch.FloatTensor(X_train.reshape(x_shape)), torch.LongTensor(y_train))\n",
    "val_ds = TensorDataset(torch.FloatTensor(X_val.reshape(x_shape)), torch.LongTensor(y_val))\n",
    "train_loader = DataLoader(train_ds, batch_size=512)\n",
    "val_loader = DataLoader(val_ds, batch_size=512)\n",
    "cal_ds = TensorDataset(torch.FloatTensor(X_cal.reshape(x_shape)), torch.LongTensor(y_cal))\n",
    "cal_loader = DataLoader(cal_ds, batch_size=512)\n",
    "X_val.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):   \n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "\n",
    "        self.cnn_layers = nn.Sequential(\n",
    "            nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1),\n",
    "            nn.BatchNorm2d(4),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "            nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1),\n",
    "            nn.BatchNorm2d(4),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "        )\n",
    "\n",
    "        self.linear_layers = nn.Sequential(\n",
    "            nn.Linear(4 * 7 * 7, 10)\n",
    "        )\n",
    "  \n",
    "    def forward(self, x):\n",
    "        x = self.cnn_layers(x)\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = self.linear_layers(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 153
    },
    "colab_type": "code",
    "id": "8qM2NebHB2ra",
    "outputId": "eaea1671-b680-489a-8cb6-68a734ab759e"
   },
   "outputs": [],
   "source": [
    "model = Net()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "..............................................................................................\n",
      "Train loss on last batch 0.5327802896499634\n",
      "..............................................................................................\n",
      "Train loss on last batch 0.27037015557289124\n",
      "..............................................................................................\n",
      "Train loss on last batch 0.1935376077890396\n",
      "..............................................................................................\n",
      "Train loss on last batch 0.15540318191051483\n",
      "..............................................................................................\n",
      "Train loss on last batch 0.13320226967334747\n",
      "..............................................................................................\n",
      "Train loss on last batch 0.11954189091920853\n",
      "..............................................................................................\n",
      "Train loss on last batch 0.11020981520414352\n",
      "Accuracy 0.97265625\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(7):\n",
    "    for x_batch, y_batch in train_loader: # Train for one epoch\n",
    "        print('.', end='')\n",
    "        prediction = model(x_batch)\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(prediction, y_batch)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    print('\\nTrain loss on last batch', loss.item())\n",
    "\n",
    "# Check accuracy\n",
    "x_batch, y_batch = next(iter(val_loader))\n",
    "\n",
    "\n",
    "class_preds = f.softmax(model(x_batch), dim=-1).detach().numpy()\n",
    "predictions = np.argmax(class_preds, axis=-1)\n",
    "print('Accuracy', accuracy_score(predictions, y_batch))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calibration\n",
    "\n",
    "Calibration is applied to the logit outputs of the network. The usual pipeline is as follows:\n",
    "- Train the model with logit outputs\n",
    "- Calculate the logits for some calibration dataset (it is basically a validation dataset in some sense)\n",
    "- Wrap the model with a calibration model and train it on the calibration logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ -2.3510,  -2.6608,  -4.4777,  ...,   9.7287,  -2.8388,   5.9241],\n",
       "        [  7.5125,  -5.3876,   0.5365,  ...,  -7.2053,  -0.0680,  -1.2508],\n",
       "        [ -1.4826,   6.4404,  -0.1426,  ...,  -1.9341,  -0.6942,  -1.3237],\n",
       "        ...,\n",
       "        [ -2.7429,  -6.0311,  -5.5568,  ...,  -8.2758,   5.0076,   0.1018],\n",
       "        [ -3.4007,  -6.8022,  -0.8077,  ...,   1.9691,  -1.2460,   8.5202],\n",
       "        [  8.9903, -10.3975,  -1.4172,  ...,  -7.7311,   0.1962,   0.8563]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logits_list = []\n",
    "labels_list = []\n",
    "for x_batch, y_batch in cal_loader:\n",
    "    logits_list.append(model(x_batch))\n",
    "    labels_list.append(y_batch)\n",
    "logits = torch.cat(logits_list)\n",
    "labels = torch.cat(labels_list)\n",
    "logits.detach_()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Temperature Scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "calibr = calibrator.ModelWithTempScaling(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ModelWithTempScaling(\n",
       "  (model): Net(\n",
       "    (cnn_layers): Sequential(\n",
       "      (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (4): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (5): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (6): ReLU(inplace=True)\n",
       "      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (linear_layers): Sequential(\n",
       "      (0): Linear(in_features=196, out_features=10, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "calibr.scaling(logits, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_logits_list = []\n",
    "val_labels_list = []\n",
    "for x_batch, y_batch in val_loader:\n",
    "    val_logits_list.append(model(x_batch))\n",
    "    val_labels_list.append(y_batch)\n",
    "val_logits = torch.cat(val_logits_list)\n",
    "val_labels = torch.cat(val_labels_list)\n",
    "val_logits.detach_()\n",
    "probs = f.softmax(val_logits, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ece  =  tensor([0.0169])\n",
      "sce  =  tensor([0.0022])\n",
      "ace  =  tensor(0.0205)\n",
      "tace  =  tensor(0.0111)\n"
     ]
    }
   ],
   "source": [
    "compute_errors(n_bins=15, probs=probs.numpy(), labels=val_labels.numpy(),\n",
    "               len_dataset=np.shape(probs)[0], threshold=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([0.6644], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "print(calibr.temperature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ece  =  tensor([0.0066])\n",
      "sce  =  tensor([0.0012])\n",
      "ace  =  tensor(0.0144)\n",
      "tace  =  tensor(0.0059)\n"
     ]
    }
   ],
   "source": [
    "temp_scaling_probs_list = []\n",
    "for x_batch, y_batch in val_loader:\n",
    "    temp_scaling_probs_list.append(calibr.forward(x_batch))\n",
    "temp_scaling_probs = torch.cat(temp_scaling_probs_list)\n",
    "compute_errors(n_bins=15, probs=temp_scaling_probs.detach().numpy(), labels=val_labels.numpy(),\n",
    "               len_dataset=np.shape(probs)[0], threshold=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Vector Scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "McvVpiIqR6Nj",
    "outputId": "ed044e56-c766-457d-942a-7af91ac419f8"
   },
   "outputs": [],
   "source": [
    "calibr = calibrator.ModelWithVectScaling(model, n_classes=10).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ModelWithVectScaling(\n",
       "  (model): Net(\n",
       "    (cnn_layers): Sequential(\n",
       "      (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (4): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (5): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (6): ReLU(inplace=True)\n",
       "      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (linear_layers): Sequential(\n",
       "      (0): Linear(in_features=196, out_features=10, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "calibr.scaling(logits, labels, lr=0.001, max_iter=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "vect_scaling_probs_list = []\n",
    "for x_batch, y_batch in val_loader:\n",
    "    vect_scaling_probs_list.append(calibr.forward(x_batch))\n",
    "vect_scaling_probs = torch.cat(vect_scaling_probs_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ece  =  tensor([0.0049])\n",
      "sce  =  tensor([0.0013])\n",
      "ace  =  tensor(0.0141)\n",
      "tace  =  tensor(0.0058)\n"
     ]
    }
   ],
   "source": [
    "compute_errors(n_bins=15, probs=vect_scaling_probs.detach().numpy(), labels=val_labels.numpy(),\n",
    "               len_dataset=np.shape(probs)[0], threshold=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([ 1.1088,  1.1993,  1.0772,  1.1848,  1.2557,  1.1746,  1.2312,  1.3807,\n",
       "         1.1934,  1.3136, -0.0117,  0.0208, -0.0524, -0.0219, -0.0316, -0.0127,\n",
       "         0.0147,  0.0682, -0.0066,  0.0332], requires_grad=True)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "calibr.W_and_b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Matrix Scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ModelWithMatrScaling(\n",
       "  (model): Net(\n",
       "    (cnn_layers): Sequential(\n",
       "      (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (2): ReLU(inplace=True)\n",
       "      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "      (4): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (5): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (6): ReLU(inplace=True)\n",
       "      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (linear_layers): Sequential(\n",
       "      (0): Linear(in_features=196, out_features=10, bias=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "calibr = calibrator.ModelWithMatrScaling(model, n_classes=10).float()\n",
    "calibr.scaling(logits, labels, lr=0.0001, max_iter=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "matr_scaling_probs_list = []\n",
    "for x_batch, y_batch in val_loader:\n",
    "    matr_scaling_probs_list.append(calibr.forward(x_batch))\n",
    "matr_scaling_probs = torch.cat(matr_scaling_probs_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ece  =  tensor([0.0048])\n",
      "sce  =  tensor([0.0013])\n",
      "ace  =  tensor(0.0138)\n",
      "tace  =  tensor(0.0056)\n"
     ]
    }
   ],
   "source": [
    "compute_errors(n_bins=15, probs=matr_scaling_probs.detach().numpy(), labels=val_labels.numpy(),\n",
    "               len_dataset=np.shape(probs)[0], threshold=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "hist_binning_probs = calibrator.multiclass_histogram_binning(15, logits.numpy(), labels.numpy(), val_logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ece  =  tensor([0.0074])\n",
      "sce  =  tensor([0.0015])\n",
      "ace  =  tensor(0.0150, dtype=torch.float64)\n",
      "tace  =  tensor(0.0109, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "compute_errors(n_bins=15, probs=hist_binning_probs, labels=val_labels.numpy(),\n",
    "               len_dataset=np.shape(probs)[0], threshold=0.9)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}